Allow passing in DistillationConfig directly to setup fn #399
Conversation
Signed-off-by: Asha Anoosheh <[email protected]>
Walkthrough

Renames and reworks the distillation config loader to accept a path, a DistillationConfig instance, or None. Implements branching: None yields a default config, a provided DistillationConfig is used directly, or a YAML path is loaded and parsed. Updates docstrings and replaces the old function.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant C as Caller
    participant M as setup_distillation_config
    participant Y as YAML Loader

    C->>M: setup_distillation_config(config_or_path, student_cfg, teacher_cfg)
    alt config_or_path is None
        M-->>C: Default DistillationConfig
    else config_or_path is DistillationConfig
        M-->>C: Provided DistillationConfig (as-is)
    else config_or_path is str (path)
        M->>Y: Load & parse YAML
        Y-->>M: Parsed settings
        M-->>C: DistillationConfig from YAML
    end
```
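For callers, the three accepted inputs then look like this (a hedged sketch; `student_cfg` and `teacher_cfg` stand in for the Megatron model config objects, and the YAML filename is made up):

```python
# Illustrative only -- each call returns a ready-to-use DistillationConfig.
cfg = setup_distillation_config(None, student_cfg, teacher_cfg)            # default logits-only config
cfg = setup_distillation_config("distill.yaml", student_cfg, teacher_cfg)  # load settings from YAML
cfg = setup_distillation_config(DistillationConfig(), student_cfg, teacher_cfg)  # use instance as-is
```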
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 1
🧹 Nitpick comments (2)
modelopt/torch/distill/plugins/megatron.py (2)
Lines 100-109: Clarify the docstring for the new input modes.

The docstring update mentions "the incomplete config itself", which is unclear. Consider revising to explicitly document all three input modes. Apply this diff to improve the docstring:
- """Read the distillation yaml config file specified by ``args.export_kd_cfg``. + """Setup distillation configuration from various input sources. Args: - config_or_path: Path to user-defined distillation settings yaml file, or the incomplete config itself. - If `None`, uses default logits-only distillation mode for GPT models. + config_or_path: One of: + - `None`: Uses default logits-only distillation mode for GPT models. + - `DistillationConfig`: Uses the provided config instance directly. + - `str`: Path to a YAML file containing distillation settings. student_cfg: Model config for student model. teacher_cfg: Model config for teacher model.
Lines 115-118: Add error handling for file operations and YAML parsing.

The code opens a file and parses YAML without error handling. If the path doesn't exist, the file is malformed, or the YAML contains invalid fields for DistillationConfig, the function will raise an unhandled exception. Apply this diff to add basic error handling:
```diff
     else:
-        with open(config_or_path) as f:
-            cfg = yaml.safe_load(f)
-        cfg = DistillationConfig(**cfg)
+        try:
+            with open(config_or_path) as f:
+                cfg_dict = yaml.safe_load(f)
+            cfg = DistillationConfig(**cfg_dict)
+        except FileNotFoundError:
+            raise FileNotFoundError(f"Distillation config file not found: {config_or_path}")
+        except yaml.YAMLError as e:
+            raise ValueError(f"Invalid YAML in distillation config: {e}")
+        except TypeError as e:
+            raise ValueError(f"Invalid distillation config fields: {e}")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/distill/plugins/megatron.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/distill/plugins/megatron.py (1)
Lines 95-99: Rename verified: no remaining references to load_distillation_config. All internal callers are updated.
```python
if config_or_path is None:
    logger.warning("Distillation config not provided. Using default.")
    cfg = DistillationConfig()
elif isinstance(config_or_path, DistillationConfig):
    cfg = config_or_path
else:
    with open(config_or_path) as f:
        cfg = yaml.safe_load(f)
    cfg = DistillationConfig(**cfg)
```
Consider the implications of mutating the input DistillationConfig.
When a DistillationConfig instance is passed in (line 114), the function directly assigns it to cfg and later mutates it by setting cfg.criterion (line 146) and cfg.loss_balancer (line 147). This means the caller's original config object will be modified, which may be unexpected if they intend to reuse the config.
Consider one of these approaches:
- Document the mutation in the docstring:
```diff
         config_or_path: One of:
             - `None`: Uses default logits-only distillation mode for GPT models.
-            - `DistillationConfig`: Uses the provided config instance directly.
+            - `DistillationConfig`: Uses the provided config instance directly (will be modified in-place).
             - `str`: Path to a YAML file containing distillation settings.
```
- Create a copy to avoid mutating the input:
```diff
 elif isinstance(config_or_path, DistillationConfig):
-    cfg = config_or_path
+    from copy import deepcopy
+    cfg = deepcopy(config_or_path)
```
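The aliasing concern above can be reproduced with a standalone sketch (a generic dataclass stands in for DistillationConfig; the field name is made up):

```python
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional


@dataclass
class Cfg:  # stand-in for DistillationConfig; `criterion` is illustrative
    criterion: Optional[str] = None


def setup_aliasing(cfg: Cfg) -> Cfg:
    cfg.criterion = "logits_kl"  # mutates the caller's object in place
    return cfg


def setup_copying(cfg: Cfg) -> Cfg:
    cfg = deepcopy(cfg)  # work on a copy; the caller's object stays untouched
    cfg.criterion = "logits_kl"
    return cfg


original = Cfg()
setup_aliasing(original)
print(original.criterion)  # "logits_kl" -- surprising if the caller reuses `original`

original = Cfg()
setup_copying(original)
print(original.criterion)  # None -- unchanged
```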
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if config_or_path is None:
    logger.warning("Distillation config not provided. Using default.")
    cfg = DistillationConfig()
elif isinstance(config_or_path, DistillationConfig):
    from copy import deepcopy
    cfg = deepcopy(config_or_path)
else:
    with open(config_or_path) as f:
        cfg = yaml.safe_load(f)
    cfg = DistillationConfig(**cfg)
```
🤖 Prompt for AI Agents
In modelopt/torch/distill/plugins/megatron.py around lines 110 to 118, the
function assigns a passed-in DistillationConfig instance directly to cfg and
later mutates it, which unintentionally alters the caller's object; to fix this,
create a shallow or deep copy of the incoming DistillationConfig (e.g., via
copy.deepcopy or a provided copy/clone/from_dict constructor) and assign that
copy to cfg before any mutations so the original remains unchanged, or
alternatively document in the function docstring that the input config will be
mutated if that behavior is intended.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main     #399   +/-   ##
=======================================
  Coverage   73.79%   73.79%
=======================================
  Files         171      171
  Lines       17591    17591
=======================================
  Hits        12982    12982
  Misses       4609     4609
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Asha Anoosheh <[email protected]>
What does this PR do?
Type of change: Minor tweak
Overview: ?
Usage
# Add a code snippet demonstrating how to use this
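A minimal usage sketch for the new code path (the import locations and the student/teacher model config objects are assumptions, not taken from this PR):

```python
from modelopt.torch.distill.plugins.megatron import (
    DistillationConfig,
    setup_distillation_config,
)


def build_kd_config(student_cfg, teacher_cfg):
    """student_cfg / teacher_cfg: the Megatron model configs (placeholders here)."""
    kd_cfg = DistillationConfig()  # construct or customize the config in code
    # New in this PR: pass the DistillationConfig instance directly
    # instead of a path to a YAML file (or None for the default).
    return setup_distillation_config(kd_cfg, student_cfg, teacher_cfg)
```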
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Refactor